import torch
import torchvision
from torch.utils.data import DataLoader, TensorDataset
import random

class Sampler:
    """Minimal base class for samplers: remembers a device and declares `sample`."""

    def __init__(self, device='cpu'):
        # Device identifier kept for subclasses to move their batches onto.
        self.device = device

    def sample(self, size=5):
        """Placeholder to be overridden by subclasses; does nothing and returns None."""
        return None
    
class LoaderSampler(Sampler):
    """Sampler that draws fixed-size batches from a data loader, cycling over it.

    The loader is expected to yield ``(batch, other)`` pairs; only the first
    element is used. When the loader's iterator is exhausted it is restarted,
    so `sample` can be called indefinitely.
    """

    def __init__(self, loader, device='cpu'):
        super(LoaderSampler, self).__init__(device)
        self.loader = loader
        self.it = iter(self.loader)

    def sample(self, size=5):
        """Return the first `size` items of the next sufficiently large batch.

        Batches with fewer than `size` items (e.g. a short final batch) are
        skipped. The result is moved to ``self.device``.

        Raises:
            AssertionError: if `size` exceeds the loader's ``batch_size``.
            RuntimeError: if a complete pass over the loader yields no batch
                with at least `size` items, so the request can never be met.
        """
        # batch_size is None when the loader was built with a batch_sampler;
        # in that case there is no fixed bound to check against.
        batch_size = self.loader.batch_size
        if batch_size is not None:
            assert size <= batch_size, (
                f"requested size {size} exceeds loader batch_size {batch_size}"
            )

        # Iterate instead of recursing so an unsatisfiable request fails with a
        # clear error rather than exhausting the interpreter's recursion limit.
        restarts = 0
        while True:
            try:
                batch, _ = next(self.it)
            except StopIteration:
                self.it = iter(self.loader)
                restarts += 1
                # The first restart may simply mean the previous epoch ended.
                # A second one means a full pass produced nothing usable.
                if restarts > 1:
                    raise RuntimeError(
                        f"loader produced no batch with at least {size} items"
                    )
                continue
            if len(batch) >= size:
                return batch[:size].to(self.device)
    
class DatasetSampler(Sampler):
    """Sampler that loads a whole dataset into one tensor and draws random rows.

    On construction the dataset is read once through a ``DataLoader`` and all
    batches are concatenated into ``self.dataset``. When `flag_label` is
    truthy each loader item is unpacked as an ``(inputs, label)`` pair and
    only the inputs are kept; otherwise each loader item is used as-is.
    """

    def __init__(self, dataset, flag_label, batch_size,
                 num_workers=40, device='cpu'):
        super().__init__(device=device)

        self.loader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers)
        self.flag_label = flag_label

        with torch.no_grad():
            if self.flag_label:
                chunks = [inputs for (inputs, _) in self.loader]
            else:
                chunks = list(self.loader)
            self.dataset = torch.cat(chunks)

    def sample(self, batch_size=8):
        """Return `batch_size` rows drawn uniformly with replacement.

        The rows are copied, moved to ``self.device``, converted to float and
        clamped to the range [0, 1].
        """
        population = range(len(self.dataset))
        picks = random.choices(population, k=batch_size)
        with torch.no_grad():
            drawn = self.dataset[picks].clone().to(self.device).float()
            return torch.clamp(drawn, 0, 1)
    
    
def _gather_classes(dataset, classes):
    """Collect the transformed images of the requested classes, class by class.

    Returns a ``(data, labels)`` pair. ``data`` holds the images of
    ``classes[0]`` first, then ``classes[1]``, and so on, reshaped to
    ``(-1, 1, 32, 32)``. ``labels`` holds, for each image, the position of its
    class inside `classes` (not the original target value).
    """
    # Read the targets once as plain Python ints; comparing them is far cheaper
    # than indexing the targets tensor element by element for every class.
    targets = torch.as_tensor(dataset.targets).tolist()

    per_class = []
    labels = []
    for position, cls in enumerate(classes):
        images = torch.stack(
            [dataset[i][0] for i, target in enumerate(targets) if target == cls],
            dim=0,
        )
        per_class.append(images)
        labels += [position] * images.shape[0]

    data = torch.cat(per_class, dim=0).reshape(-1, 1, 32, 32)
    return data, torch.tensor(labels)


def load_dataset(name, path, img_size=64, batch_size=64, shuffle=True, device='cuda'):
    """Load MNIST (optionally restricted to some digits) and build samplers.

    To use only certain classes of the MNIST dataset, list them in `name`
    using the format ``"MNIST_{digit}_{digit}_..._{digit}"``. With no digits
    in `name`, all ten classes are used. Only the digits after the first
    underscore are interpreted; MNIST is always the dataset that is loaded.

    Args:
        name: dataset name, optionally followed by ``_``-separated digits.
        path: root directory passed to ``torchvision.datasets.MNIST``
            (the data is downloaded there if missing).
        img_size: accepted for interface compatibility but not used; images
            are always resized to 32x32.
        batch_size: batch size of the loaders behind the returned samplers.
        shuffle: whether those loaders shuffle.
        device: device the samplers move their batches to.

    Returns:
        ``(train_set, test_set, train_sampler, test_sampler)`` where the sets
        are ``TensorDataset`` objects of images and class-position labels.

    Raises:
        ValueError: if a token after the first underscore is not an integer.

    Note:
        Each split is truncated to a fixed number of items *after* the images
        have been concatenated class by class, so when the selected classes
        hold more items than the limit, the later classes are cut first.
    """
    transform = torchvision.transforms.Compose([
        torchvision.transforms.Resize((32, 32)),
        torchvision.transforms.ToTensor(),
    ])

    classes = [int(token) for token in name.split("_")[1:]] or list(range(10))

    train_raw = torchvision.datasets.MNIST(path, train=True, transform=transform, download=True)
    test_raw = torchvision.datasets.MNIST(path, train=False, transform=transform, download=True)

    # Fixed per-split caps on the number of retained items.
    train_limit = 12_211
    test_limit = 2_063

    subsets = []
    for raw, limit in ((train_raw, train_limit), (test_raw, test_limit)):
        data, labels = _gather_classes(raw, classes)
        subsets.append(TensorDataset(data[:limit], labels[:limit]))

    train_set, test_set = subsets

    train_sampler = LoaderSampler(DataLoader(train_set, shuffle=shuffle, num_workers=8, batch_size=batch_size), device)
    test_sampler = LoaderSampler(DataLoader(test_set, shuffle=shuffle, num_workers=8, batch_size=batch_size), device)
    return train_set, test_set, train_sampler, test_sampler
